# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DenseQuant."""
from __future__ import absolute_import

from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import QuantDtype
from mindspore.nn.layer.activation import get_activation
from mindspore.nn.cell import Cell
from mindspore.nn.layer.basic import Dense
from mindspore_gs.validator import Validator
from .fake_quant_with_min_max_observer import quant_config_default, QuantConfig


class DenseQuant(Cell):
    r"""
    The fully connected layer with fake quantized operation.

    This part is a more detailed overview of Dense operation. For more details about Quantization,
    please refer to the implementation of class of `FakeQuantWithMinMaxObserver`,
    :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    The weight is passed through the weight fake-quant cell built from `quant_config.weight` (per-channel
    along axis 0, i.e. per output channel) before the matrix multiplication; the bias and the activation
    are applied to the float result.

    Args:
        in_channels (int): The dimension of the input space.
        out_channels (int): The dimension of the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as `x`. The values of str refer to the function `initializer`. If a Tensor is given, its shape
            must be :math:`(out\_channels, in\_channels)`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as `x`. The values of str refer to the function `initializer`. If a Tensor is given, its shape
            must be :math:`(out\_channels,)`. Only used when `has_bias` is True. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Union[str, Cell, Primitive]): The activation function applied to the output of the layer,
            e.g. 'relu'. Default: None.
        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
            activation. Note that, QuantConfig is a special namedtuple, which is designed for quantization
            and can be generated by :func:`mindspore.compression.quant.create_quant_config` method.
            Only its `weight` item is used by this layer.
            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Currently accepted for interface
            compatibility but ignored by this layer. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - 2-D Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Raises:
        TypeError: If `in_channels`, `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not None, str, Cell or Primitive.
        ValueError: If `in_channels` or `out_channels` is less than 1.
        ValueError: If the dims of `weight_init` is not equal to 2 or the first element of `weight_init` is not equal
            to `out_channels` or the second element of `weight_init` is not equal to `in_channels`.
        ValueError: If the dims of `bias_init` is not equal to 1 or the element of `bias_init` is not equal
            to `out_channels`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.compression import quant
        >>> from mindspore import Tensor, nn
        >>> qconfig = quant.create_quant_config()
        >>> dense_quant = nn.DenseQuant(2, 1, weight_init='ones', quant_config=qconfig)
        >>> x = Tensor(np.array([[1, 5], [3, 4]]), mindspore.float32)
        >>> result = dense_quant(x)
        >>> print(result)
        [[5.929413]
         [6.9176483]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        """Initialize DenseQuant."""
        super(DenseQuant, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
        _ = quant_dtype  # for fix pylint unused-argument

        # Validate explicit Tensor initializers up front so the error names the offending argument;
        # report the shape rather than dumping the whole tensor into the message.
        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' should "
                                 f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
                                 f"second dim must be equal to 'in_channels'. But got 'weight_init' with shape "
                                 f"{weight_init.shape}, 'out_channels': {out_channels}, "
                                 f"'in_channels': {in_channels}.")

        # Weight layout is (out_channels, in_channels); MatMul below transposes it.
        self.weight = Parameter(initializer(
            weight_init, [out_channels, in_channels]), name="weight")

        # `self.bias` only exists when has_bias is True; construct/extend_repr guard on has_bias.
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                     f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                     f"'bias_init' with shape {bias_init.shape}, 'out_channels': {out_channels}.")

            self.bias = Parameter(initializer(
                bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

        # Strings are resolved to activation cells; Cell/Primitive instances are used as-is.
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, "
                            f"but got {activation}.")

        self.activation_flag = self.activation is not None
        # Per-channel weight fake quantization along axis 0 (one channel per output feature).
        self.fake_quant_weight = quant_config.weight(channel_axis=0,
                                                     num_channels=out_channels)

    @classmethod
    def from_dense(cls, dense: Dense, quant_config: QuantConfig):
        """
        A class method to create `DenseQuant` from a `Dense`.

        The new cell copies `dense`'s channel sizes, bias flag and activation, then re-uses `dense`'s weight
        (and bias, when `dense.has_bias` is True) Parameter objects directly rather than copies of them.

        Args:
            dense (Dense): The float fully connected cell to convert.
            quant_config (QuantConfig): Quantization config; its `weight` item builds the weight fake-quant cell.

        Returns:
            DenseQuant, the fake-quantized counterpart of `dense`.
        """
        # A bias-less Dense may not carry a usable `bias` attribute; bias_init is ignored when has_bias is False.
        src_bias = getattr(dense, "bias", None)
        dense_quant = cls(in_channels=dense.in_channels,
                          out_channels=dense.out_channels,
                          weight_init=dense.weight,
                          bias_init=src_bias,
                          has_bias=dense.has_bias,
                          activation=dense.activation,
                          quant_config=quant_config)
        # Share the original Parameters so trained values (and any references to them) are preserved.
        dense_quant.weight = dense.weight
        if dense.has_bias:
            dense_quant.bias = dense.bias
        return dense_quant

    def construct(self, x):
        """Use operators to construct the Dense layer.

        Args:
            x (Tensor): Input tensor of shape (N, in_channels).

        Returns:
            Tensor of shape (N, out_channels): `x @ fake_quant(weight).T`, plus bias if enabled, then the
            activation if one was configured.
        """
        output = self.fake_quant_weight(self.weight)
        output = self.matmul(x, output)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.activation_flag:
            return self.activation(output)
        return output

    def extend_repr(self):
        """A pretty print for Dense layer."""
        s = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
            self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            s += ', bias={}'.format(self.bias)
        if self.activation_flag:
            s += ', activation={}'.format(self.activation)
        return s
